// Copyright Epic Games, Inc. All Rights Reserved.

#include <zencore/compositebuffer.h>

#include <zencore/sharedbuffer.h>
#include <zencore/testing.h>
#include <utility>

namespace zen {

// Shared default-constructed instance; like any default-constructed CompositeBuffer it holds no segments.
const CompositeBuffer CompositeBuffer::Null;

// Releases every segment reference; afterwards the composite holds no segments.
void
CompositeBuffer::Reset()
{
	m_Segments.clear();
}

uint64_t
CompositeBuffer::GetSize() const
{
	uint64_t Accum = 0;

	for (const SharedBuffer& It : m_Segments)
	{
		Accum += It.GetSize();
	}

	return Accum;
}

bool
CompositeBuffer::IsOwned() const
{
	for (const SharedBuffer& It : m_Segments)
	{
		if (It.IsOwned() == false)
		{
			return false;
		}
	}
	return true;
}

// Lvalue overload: copies the segment list, then lets the rvalue overload convert that copy in place.
// The original composite is left untouched.
CompositeBuffer
CompositeBuffer::MakeOwned() const&
{
	CompositeBuffer Copy(*this);
	return std::move(Copy).MakeOwned();
}

// Rvalue overload: replaces each segment with its owned equivalent (SharedBuffer::MakeOwned on the moved segment),
// then moves the whole composite out as the result.
CompositeBuffer
CompositeBuffer::MakeOwned() &&
{
	for (size_t Index = 0; Index < m_Segments.size(); ++Index)
	{
		SharedBuffer& Part = m_Segments[Index];
		Part			   = std::move(Part).MakeOwned();
	}
	return std::move(*this);
}

// Returns the contents as one contiguous SharedBuffer.
// - No segments: a default SharedBuffer.
// - One segment: that segment is shared without copying.
// - Otherwise: every segment is copied, in order, into a new allocation of GetSize() bytes.
SharedBuffer
CompositeBuffer::ToShared() const&
{
	const size_t SegmentCount = m_Segments.size();
	if (SegmentCount == 0)
	{
		return SharedBuffer();
	}
	if (SegmentCount == 1)
	{
		return m_Segments[0];
	}

	UniqueBuffer	  Joined = UniqueBuffer::Alloc(GetSize());
	MutableMemoryView Cursor = Joined.GetMutableView();
	for (const SharedBuffer& Part : m_Segments)
	{
		Cursor.CopyFrom(Part.GetView());
		// Advance past the bytes just written.
		Cursor += Part.GetSize();
	}
	return Joined.MoveToShared();
}

// Rvalue overload: a single segment is moved out directly. Any other segment count is handled through Flatten()
// on a const view of this object.
SharedBuffer
CompositeBuffer::ToShared() &&
{
	if (m_Segments.size() != 1)
	{
		return std::as_const(*this).Flatten();
	}
	return std::move(m_Segments[0]);
}

// Returns a composite referencing bytes [Offset, Offset + Size) of this buffer, without copying any bytes.
// Offset and Size are first clamped to the buffer bounds.
// - A segment fully covered by the window is shared as-is.
// - A partially covered segment is referenced through a sub-range IoBuffer.
// - For a zero-size request, one zero-length piece is still emitted at the starting position, so the result is
//   not segment-less when this buffer has segments.
CompositeBuffer
CompositeBuffer::Mid(uint64_t Offset, uint64_t Size) const
{
	const uint64_t TotalSize = GetSize();
	Offset					 = Min(Offset, TotalSize);
	Size					 = Min(Size, TotalSize - Offset);

	CompositeBuffer Result;
	for (const SharedBuffer& Segment : m_Segments)
	{
		const uint64_t SegmentSize = Segment.GetSize();
		if (Offset > SegmentSize)
		{
			// The window starts after this segment.
			Offset -= SegmentSize;
			continue;
		}

		const uint64_t TakeSize = Min(Size, SegmentSize - Offset);
		if (TakeSize == SegmentSize)
		{
			Result.m_Segments.push_back(Segment);
		}
		else if (TakeSize != 0 || Size == 0)
		{
			// Keep the (possibly zero-length) piece when the caller asked for zero bytes.
			Result.m_Segments.push_back(SharedBuffer(IoBuffer(Segment.AsIoBuffer(), Offset, TakeSize)));
		}

		// Every later segment is consumed from its start.
		Offset = 0;
		Size -= TakeSize;
		if (Size == 0)
		{
			break;
		}
	}
	return Result;
}

// Returns a view of bytes [Offset, Offset + Size).
//
// Direct view: when the whole range is exactly one complete segment, that segment's own view is returned and
// CopyBuffer is left untouched.
//
// Copy: in every other case, the range is gathered into CopyBuffer and a view of its first Size bytes is returned.
// CopyBuffer is replaced with Allocator(Size) only when it is smaller than Size.
//
// A direct view is never taken of a partial segment. IterateRange reads partial pieces through a temporary
// sub-range IoBuffer, so such a view should not be relied on after the visitor returns. The returned view is valid
// only while this composite (or CopyBuffer, in the copy case) is alive and unchanged.
MemoryView
CompositeBuffer::ViewOrCopyRange(uint64_t									Offset,
								 uint64_t									Size,
								 UniqueBuffer&								CopyBuffer,
								 std::function<UniqueBuffer(uint64_t Size)> Allocator) const
{
	MemoryView View;
	IterateRange(
		Offset,
		Size,
		[Size, &View, &CopyBuffer, &Allocator, WriteView = MutableMemoryView()](MemoryView Segment, const SharedBuffer& ViewOuter) mutable {
			// Both conditions are required:
			// - Segment.GetSize() == Size: this single piece satisfies the whole request. Without it, a complete
			//   segment that is only part of a multi-segment range would overwrite View and discard the gathered
			//   copy.
			// - Segment.GetSize() == ViewOuter.GetSize(): the piece is the whole segment, which this composite keeps
			//   in memory.
			if (Segment.GetSize() == Size && Segment.GetSize() == ViewOuter.GetSize())
			{
				View = Segment;
			}
			else
			{
				if (WriteView.IsEmpty())
				{
					// First piece of a gathered range: make sure CopyBuffer is large enough and point View at it.
					if (CopyBuffer.GetSize() < Size)
					{
						CopyBuffer = Allocator(Size);
					}
					View = WriteView = CopyBuffer.GetMutableView().Left(Size);
				}
				// Append the piece and keep the unwritten remainder as the next destination.
				WriteView = WriteView.CopyFrom(Segment);
			}
		});
	return View;
}

// Locates the segment holding byte Offset and the position of that byte inside the segment.
// Zero-length segments are skipped, because no offset lies inside them.
// When Offset is at or past the end of the buffer (including a buffer with no segments), both fields are ~0ull.
CompositeBuffer::Iterator
CompositeBuffer::GetIterator(uint64_t Offset) const
{
	for (size_t Index = 0; Index < m_Segments.size(); ++Index)
	{
		const size_t Length = m_Segments[Index].GetSize();
		if (Offset < Length)
		{
			return {.SegmentIndex = Index, .OffsetInSegment = Offset};
		}
		Offset -= Length;
	}
	return {.SegmentIndex = ~0ull, .OffsetInSegment = ~0ull};
}

// Always copies: up to Size bytes starting at iterator It are gathered into CopyBuffer, and It is advanced past the
// bytes consumed.
// - CopyBuffer is reallocated with UniqueBuffer::Alloc(Size) only when it is too small.
// - Returns a view of CopyBuffer covering the bytes actually copied. This is shorter than Size if the buffer ends
//   first.
//
// Each segment is read through a sub-range IoBuffer, so GetView() materializes only the requested slice, not the
// full segment. This matters on hot paths such as CompressedBuffer::FromCompressed, which reads only a small header
// (about the first 64 bytes) and then discards the materialized data.
MemoryView
CompositeBuffer::ViewOrCopyRange(Iterator& It, uint64_t Size, UniqueBuffer& CopyBuffer) const
{
	if (CopyBuffer.GetSize() < Size)
	{
		CopyBuffer = UniqueBuffer::Alloc(Size);
	}
	MutableMemoryView Destination  = CopyBuffer.GetMutableView();
	const size_t	  SegmentCount = m_Segments.size();
	ZEN_ASSERT(It.SegmentIndex < SegmentCount);

	uint64_t Copied = 0;
	while (Copied < Size && It.SegmentIndex < SegmentCount)
	{
		const SharedBuffer& Source	   = m_Segments[It.SegmentIndex];
		const size_t		SourceSize = Source.GetSize();
		const size_t		ChunkSize  = zen::Min(SourceSize - It.OffsetInSegment, Size - Copied);

		IoBuffer Slice(Source.AsIoBuffer(), It.OffsetInSegment, ChunkSize);
		Destination = Destination.CopyFrom(Slice.GetView());
		Copied += ChunkSize;

		It.OffsetInSegment += ChunkSize;
		ZEN_ASSERT_SLOW(It.OffsetInSegment <= SourceSize);
		if (It.OffsetInSegment == SourceSize)
		{
			// Segment exhausted: continue at the start of the next one.
			++It.SegmentIndex;
			It.OffsetInSegment = 0;
		}
	}
	return CopyBuffer.GetView().Left(Copied);
}

// Copies bytes starting at iterator It into WriteView until WriteView is full or the buffer ends, and advances It past
// the bytes consumed.
//
// Each segment is read through a sub-range IoBuffer, so GetView() materializes only the slice being copied rather
// than the full segment.
void
CompositeBuffer::CopyTo(MutableMemoryView WriteView, Iterator& It) const
{
	const size_t SegmentCount = m_Segments.size();
	size_t		 Remaining	  = WriteView.GetSize();
	ZEN_ASSERT(It.SegmentIndex < SegmentCount);

	while (WriteView.GetSize() > 0 && It.SegmentIndex < SegmentCount)
	{
		const SharedBuffer& Source	   = m_Segments[It.SegmentIndex];
		const size_t		SourceSize = Source.GetSize();
		const size_t		ChunkSize  = zen::Min(SourceSize - It.OffsetInSegment, Remaining);

		IoBuffer Slice(Source.AsIoBuffer(), It.OffsetInSegment, ChunkSize);
		WriteView = WriteView.CopyFrom(Slice.GetView());
		Remaining -= ChunkSize;

		It.OffsetInSegment += ChunkSize;
		ZEN_ASSERT_SLOW(It.OffsetInSegment <= SourceSize);
		if (It.OffsetInSegment == SourceSize)
		{
			// Segment exhausted: continue at the start of the next one.
			++It.SegmentIndex;
			It.OffsetInSegment = 0;
		}
	}
}

// Fills Target with bytes [Offset, Offset + Target.GetSize()) of this buffer.
// IterateRange asserts that the range is in bounds.
void
CompositeBuffer::CopyTo(MutableMemoryView Target, uint64_t Offset) const
{
	MutableMemoryView Remaining = Target;
	IterateRange(Offset, Target.GetSize(), [&Remaining](MemoryView Piece, [[maybe_unused]] const SharedBuffer& ViewOuter) {
		// Write the piece and continue with the unwritten tail of the destination.
		Remaining = Remaining.CopyFrom(Piece);
	});
}

// View-only convenience overload: forwards to the (View, ViewOuter) overload and drops the outer segment argument.
void
CompositeBuffer::IterateRange(uint64_t Offset, uint64_t Size, std::function<void(MemoryView View)> Visitor) const
{
	// Visitor is owned by value here, so move it into the adapter instead of copying it.
	// A std::function copy may allocate.
	IterateRange(Offset, Size, [Visitor = std::move(Visitor)](MemoryView View, [[maybe_unused]] const SharedBuffer& ViewOuter) {
		Visitor(View);
	});
}

// Calls Visitor, in segment order, for each piece of bytes [Offset, Offset + Size). ViewOuter is the segment the
// piece belongs to.
// - Empty pieces are not visited. The one exception: a zero-size request whose Offset falls exactly at the end of a
//   segment is visited once with an empty view of that segment.
// - A piece covering a whole segment is that segment's own view.
// - A partial piece comes from a temporary sub-range IoBuffer, so the whole segment does not have to be
//   materialized. Do not rely on such a view after Visitor returns.
// Asserts that the range lies within the buffer.
void
CompositeBuffer::IterateRange(uint64_t															  Offset,
							  uint64_t															  Size,
							  std::function<void(MemoryView View, const SharedBuffer& ViewOuter)> Visitor) const
{
	// Check the bounds without computing Offset + Size. That sum can wrap for very large arguments and let an
	// out-of-range request pass the assertion.
	ZEN_ASSERT(Offset <= GetSize() && Size <= GetSize() - Offset);
	for (const SharedBuffer& Segment : m_Segments)
	{
		const uint64_t SegmentSize = Segment.GetSize();
		if (Size == 0 && Offset == SegmentSize)
		{
			// Special case for getting the zero size end of a composite buffer
			const MemoryView View = Segment.GetView().Mid(Offset, 0);
			Visitor(View, Segment);
			break;
		}
		if (Offset < SegmentSize)
		{
			if (Offset == 0 && Size >= SegmentSize)
			{
				// The whole segment is inside the range: hand out its own view.
				const MemoryView View = Segment.GetView();
				if (!View.IsEmpty())
				{
					Visitor(View, Segment);
				}
				Size -= View.GetSize();
				if (Size == 0)
				{
					break;
				}
			}
			else
			{
				// If we only want a section of the segment, do a subrange so we don't have to materialize the entire iobuffer
				IoBuffer		 SubRange(Segment.AsIoBuffer(), Offset, Min(Size, SegmentSize - Offset));
				const MemoryView View = SubRange.GetView();
				if (!View.IsEmpty())
				{
					Visitor(View, Segment);
				}
				Size -= View.GetSize();
				if (Size == 0)
				{
					break;
				}
				// Every later segment is read from its start.
				Offset = 0;
			}
		}
		else
		{
			// The range starts after this segment; skip it.
			Offset -= SegmentSize;
		}
	}
}

#if ZEN_WITH_TESTS
// A default-constructed composite has no segments.
// It reports null/empty everywhere, copies nothing, and IterateRange visits nothing.
TEST_CASE("CompositeBuffer Null")
{
	CompositeBuffer Buffer;
	CHECK(Buffer.IsNull());
	CHECK(Buffer.IsOwned());
	CHECK(Buffer.MakeOwned().IsNull());
	CHECK(Buffer.Flatten().IsNull());
	CHECK(Buffer.Mid(0, 0).IsNull());
	CHECK(Buffer.GetSize() == 0);
	CHECK(Buffer.GetSegments().size() == 0);

	// A zero-size request must not allocate a copy buffer.
	UniqueBuffer CopyBuffer;
	CHECK(Buffer.ViewOrCopyRange(0, 0, CopyBuffer).IsEmpty());
	CHECK(CopyBuffer.IsNull());

	MutableMemoryView CopyView;
	Buffer.CopyTo(CopyView);

	uint32_t VisitCount = 0;
	Buffer.IterateRange(0, 0, [&VisitCount](MemoryView) { ++VisitCount; });
	CHECK(VisitCount == 0);
}

// A composite holding one zero-length, non-owning view. It is not null and not owned.
// Zero-size requests return the segment's own view, and IterateRange(0, 0) visits it once through the
// end-of-segment special case.
TEST_CASE("CompositeBuffer Empty")
{
	const uint8_t	   EmptyArray[]{0};
	const SharedBuffer EmptyView = SharedBuffer::MakeView(EmptyArray, 0);
	CompositeBuffer	   Buffer(EmptyView);
	CHECK(Buffer.IsNull() == false);
	CHECK(Buffer.IsOwned() == false);
	CHECK(Buffer.MakeOwned().IsNull() == false);
	CHECK(Buffer.MakeOwned().IsOwned() == true);
	CHECK(Buffer.Flatten() == EmptyView);
	CHECK(Buffer.Mid(0, 0).Flatten() == EmptyView);
	CHECK(Buffer.GetSize() == 0);
	CHECK(Buffer.GetSegments().size() == 1);
	CHECK(Buffer.GetSegments()[0] == EmptyView);

	// The whole (empty) segment is returned directly, so no copy buffer is allocated.
	UniqueBuffer CopyBuffer;
	CHECK(Buffer.ViewOrCopyRange(0, 0, CopyBuffer) == EmptyView.GetView());
	CHECK(CopyBuffer.IsNull());

	MutableMemoryView CopyView;
	Buffer.CopyTo(CopyView);

	uint32_t VisitCount = 0;
	Buffer.IterateRange(0, 0, [&VisitCount](MemoryView) { ++VisitCount; });
	CHECK(VisitCount == 1);
}

// Two zero-length views: one at the start of the array and one just past its end.
// Zero-size requests at offset 0 resolve to the first segment only.
TEST_CASE("CompositeBuffer Empty[1]")
{
	const uint8_t	   EmptyArray[1]{};
	const SharedBuffer EmptyView1 = SharedBuffer::MakeView(EmptyArray, 0);
	const SharedBuffer EmptyView2 = SharedBuffer::MakeView(EmptyArray + 1, 0);
	CompositeBuffer	   Buffer(EmptyView1, EmptyView2);
	CHECK(Buffer.Mid(0, 0).Flatten() == EmptyView1);
	CHECK(Buffer.GetSize() == 0);
	CHECK(Buffer.GetSegments().size() == 2);
	CHECK(Buffer.GetSegments()[0] == EmptyView1);
	CHECK(Buffer.GetSegments()[1] == EmptyView2);

	UniqueBuffer CopyBuffer;
	CHECK(Buffer.ViewOrCopyRange(0, 0, CopyBuffer) == EmptyView1.GetView());
	CHECK(CopyBuffer.IsNull());

	MutableMemoryView CopyView;
	Buffer.CopyTo(CopyView);

	// IterateRange stops after visiting the first segment.
	uint32_t VisitCount = 0;
	Buffer.IterateRange(0, 0, [&VisitCount](MemoryView) { ++VisitCount; });
	CHECK(VisitCount == 1);
}

// A single owned segment (cloned from an 8-byte array).
// Checks Mid slicing, that a full-range request returns the segment's view without copying, and CopyTo at an offset.
TEST_CASE("CompositeBuffer Flat")
{
	const uint8_t	   FlatArray[]{1, 2, 3, 4, 5, 6, 7, 8};
	const SharedBuffer FlatView = SharedBuffer::Clone(MakeMemoryView(FlatArray));
	CompositeBuffer	   Buffer(FlatView);

	CHECK(Buffer.IsNull() == false);
	CHECK(Buffer.IsOwned() == true);
	CHECK(Buffer.Flatten() == FlatView);
	CHECK(Buffer.MakeOwned().Flatten() == FlatView);
	CHECK(Buffer.Mid(0).Flatten() == FlatView);
	CHECK(Buffer.Mid(4).Flatten().GetView() == FlatView.GetView().Mid(4));
	CHECK(Buffer.Mid(8).Flatten().GetView() == FlatView.GetView().Mid(8));
	CHECK(Buffer.Mid(4, 2).Flatten().GetView() == FlatView.GetView().Mid(4, 2));
	CHECK(Buffer.Mid(8, 0).Flatten().GetView() == FlatView.GetView().Mid(8, 0));
	CHECK(Buffer.GetSize() == sizeof(FlatArray));
	CHECK(Buffer.GetSegments().size() == 1);
	CHECK(Buffer.GetSegments()[0] == FlatView);

	// The full range is exactly the single segment, so it is returned directly and nothing is copied.
	UniqueBuffer CopyBuffer;
	CHECK(Buffer.ViewOrCopyRange(0, sizeof(FlatArray), CopyBuffer) == FlatView.GetView());
	CHECK(CopyBuffer.IsNull());

	uint8_t CopyArray[sizeof(FlatArray) - 3];
	Buffer.CopyTo(MakeMutableMemoryView(CopyArray), 3);
	CHECK(MakeMemoryView(CopyArray).EqualBytes(MakeMemoryView(FlatArray) + 3));

	uint32_t VisitCount = 0;
	Buffer.IterateRange(0, sizeof(FlatArray), [&VisitCount](MemoryView) { ++VisitCount; });
	CHECK(VisitCount == 1);
}

// Two non-owning 4-byte views over the halves of one 8-byte array.
// Checks slicing across the boundary, direct-view versus copy behaviour of ViewOrCopyRange, and which outer segment
// IterateRange reports for ranges touching the boundary.
TEST_CASE("CompositeBuffer Composite")
{
	const uint8_t	   FlatArray[]{1, 2, 3, 4, 5, 6, 7, 8};
	const SharedBuffer FlatView1 = SharedBuffer::MakeView(MakeMemoryView(FlatArray).Left(4));
	const SharedBuffer FlatView2 = SharedBuffer::MakeView(MakeMemoryView(FlatArray).Right(4));
	CompositeBuffer	   Buffer(FlatView1, FlatView2);

	CHECK(Buffer.IsNull() == false);
	CHECK(Buffer.IsOwned() == false);
	CHECK(Buffer.Flatten().GetView().EqualBytes(MakeMemoryView(FlatArray)));
	CHECK(Buffer.Mid(2, 4).Flatten().GetView().EqualBytes(MakeMemoryView(FlatArray).Mid(2, 4)));
	CHECK(Buffer.Mid(0, 4).Flatten() == FlatView1);
	CHECK(Buffer.Mid(4, 4).Flatten() == FlatView2);
	CHECK(Buffer.GetSize() == sizeof(FlatArray));
	CHECK(Buffer.GetSegments().size() == 2);
	CHECK(Buffer.GetSegments()[0] == FlatView1);
	CHECK(Buffer.GetSegments()[1] == FlatView2);

	UniqueBuffer CopyBuffer;

	// Ranges that are exactly one whole segment are returned as direct views; CopyBuffer stays unallocated.
	CHECK(Buffer.ViewOrCopyRange(0, 4, CopyBuffer) == FlatView1.GetView());
	CHECK(CopyBuffer.IsNull() == true);
	CHECK(Buffer.ViewOrCopyRange(4, 4, CopyBuffer) == FlatView2.GetView());
	CHECK(CopyBuffer.IsNull() == true);
	// Ranges crossing the boundary are copied. CopyBuffer is reallocated only when it is too small, so after the
	// 6-byte request it keeps size 6 for the smaller 4-byte request.
	CHECK(Buffer.ViewOrCopyRange(3, 2, CopyBuffer).EqualBytes(MakeMemoryView(FlatArray).Mid(3, 2)));
	CHECK(CopyBuffer.GetSize() == 2);
	CHECK(Buffer.ViewOrCopyRange(1, 6, CopyBuffer).EqualBytes(MakeMemoryView(FlatArray).Mid(1, 6)));
	CHECK(CopyBuffer.GetSize() == 6);
	CHECK(Buffer.ViewOrCopyRange(2, 4, CopyBuffer).EqualBytes(MakeMemoryView(FlatArray).Mid(2, 4)));
	CHECK(CopyBuffer.GetSize() == 6);

	uint8_t CopyArray[4];
	Buffer.CopyTo(MakeMutableMemoryView(CopyArray), 2);
	CHECK(MakeMemoryView(CopyArray).EqualBytes(MakeMemoryView(FlatArray).Mid(2, 4)));

	uint32_t VisitCount = 0;
	Buffer.IterateRange(0, sizeof(FlatArray), [&VisitCount](MemoryView) { ++VisitCount; });
	CHECK(VisitCount == 2);

	// Expects exactly one visit for the given range, with the given view and outer segment.
	const auto TestIterateRange =
		[&Buffer](uint64_t Offset, uint64_t Size, MemoryView ExpectedView, const SharedBuffer& ExpectedViewOuter) {
			uint32_t	 VisitCount = 0;
			MemoryView	 ActualView;
			SharedBuffer ActualViewOuter;
			Buffer.IterateRange(Offset, Size, [&VisitCount, &ActualView, &ActualViewOuter](MemoryView View, const SharedBuffer& ViewOuter) {
				++VisitCount;
				ActualView		= View;
				ActualViewOuter = ViewOuter;
			});
			CHECK(VisitCount == 1);
			CHECK(ActualView == ExpectedView);
			CHECK(ActualViewOuter == ExpectedViewOuter);
		};
	TestIterateRange(0, 4, MakeMemoryView(FlatArray).Mid(0, 4), FlatView1);
	// A zero-size range at the boundary is reported against the segment that ends there (FlatView1).
	TestIterateRange(4, 0, MakeMemoryView(FlatArray).Mid(4, 0), FlatView1);
	TestIterateRange(4, 4, MakeMemoryView(FlatArray).Mid(4, 4), FlatView2);
	TestIterateRange(8, 0, MakeMemoryView(FlatArray).Mid(8, 0), FlatView2);
}

// Intentionally empty. Presumably called from elsewhere so the linker keeps this translation unit and its TEST_CASEs
// (the usual "forcelink" pattern) -- NOTE(review): confirm against the caller.
void
compositebuffer_forcelink()
{
}
#endif

}  // namespace zen
